import os
from copy import copy

import dgl
import numpy as np
from rdkit import Chem
from torch_geometric.data import Batch

from chem.loader import MoleculeDataset, read_book
from mol import smiles2graph


def transform_graph(data_path, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2):
    """Convert every line of a SMILES text file into a graph object.

    Args:
        data_path: Path to a UTF-8 text file. Each line of the file is handed
            to ``smiles2graph`` in file order.
        train_ratio: Currently unused; kept so existing callers keep working.
        val_ratio: Currently unused; kept so existing callers keep working.
        test_ratio: Currently unused; kept so existing callers keep working.

    Returns:
        A list with one ``smiles2graph`` result per line of the file. An empty
        file yields an empty list.

    Note:
        No train/validation/test split is performed here; the ratio
        parameters are placeholders for a split that is not implemented.
    """
    # The context manager guarantees the handle is closed even if
    # ``smiles2graph`` raises part-way through the file.
    with open(data_path, 'r', encoding='utf-8') as smiles_file:
        # NOTE(review): lines are forwarded unchanged, i.e. including their
        # trailing newline -- confirm ``smiles2graph`` tolerates that.
        return [smiles2graph(line) for line in smiles_file]


from torch_geometric.data import InMemoryDataset, Data
import torch


class CustomDataset(InMemoryDataset):
    """In-memory dataset wrapping an already-built list of graph objects.

    The list passed to the constructor is kept as ``data_list`` and is also
    collated into ``self.data`` / ``self.slices``. ``save`` and ``load``
    persist and restore that list at ``self.processed_paths[0]``.

    Args:
        data_list: Graph objects to wrap (must be accepted by ``collate``).
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        transform: Optional transform, forwarded to ``InMemoryDataset``.
        pre_transform: Optional pre-transform, forwarded to
            ``InMemoryDataset``.
    """

    def __init__(self, data_list, root, transform=None, pre_transform=None):
        super(CustomDataset, self).__init__(root, transform, pre_transform)
        self.data_list = data_list
        self.data, self.slices = self.collate(data_list)
        # Kept from the original flow: ``process`` is invoked explicitly once
        # the collated data is available (it is a no-op in this class).
        self.process()

    def __len__(self):
        """Return the number of graphs held in ``data_list``."""
        return len(self.data_list)

    def get(self, idx):
        """Return a shallow copy of the graph stored at position ``idx``."""
        # ``self.data`` is one collated object whose items are looked up by
        # attribute name, not by integer position, so individual graphs are
        # served from the original list. The copy keeps later attribute
        # reassignment on the returned object from touching the stored one.
        return copy(self.data_list[idx])

    @property
    def raw_file_names(self):
        """Names of raw input files; none are required by this dataset."""
        return []

    @property
    def processed_file_names(self):
        """Names of the processed files written by ``save``."""
        return ['new_data.pt']

    def download(self):
        """No-op: there is nothing to download for this dataset."""
        pass

    def process(self):
        """No-op hook for preprocessing (feature encoding, splitting, ...)."""
        pass

    def save(self):
        """Write ``data_list`` to ``self.processed_paths[0]``."""
        torch.save(self.data_list, self.processed_paths[0])

    def load(self):
        """Reload ``data_list`` from disk and rebuild the collated tensors."""
        # NOTE: ``torch.load`` unpickles the file; only load files you trust.
        self.data_list = torch.load(self.processed_paths[0])
        self.data, self.slices = self.collate(self.data_list)


def read_data(dataset_name):
    """Load a molecule dataset and attach clique information to each graph.

    Args:
        dataset_name: Name of the dataset directory under ``../dataset/``;
            also passed to ``MoleculeDataset`` as ``dataset``.

    Returns:
        A list of shallow copies of the dataset items. Each copy additionally
        carries ``clique_index`` and ``clique_edges``, taken from the two
        structures returned by ``read_book`` at the item's position.
    """
    dataset_dir = "../dataset/" + dataset_name
    molecules = MoleculeDataset(dataset_dir, dataset=dataset_name)

    processed_prefix = dataset_dir + "/processed/"
    clique_dict, all_edges = read_book(
        processed_prefix + "clique_dict.txt",
        processed_prefix + "edge_dict.txt",
    )

    def _annotate(position, graph):
        # Work on a shallow copy so the item from the dataset is not modified.
        annotated = copy(graph)
        annotated.clique_index = clique_dict[position]
        annotated.clique_edges = all_edges[position]
        return annotated

    return [_annotate(position, graph) for position, graph in enumerate(molecules)]

from ogb.graphproppred.mol_encoder import AtomEncoder

if __name__ == '__main__':
    # Script entry point: read the 'bace' graphs, wrap them in a
    # CustomDataset and write the graph list to disk.
    # graphs = transform_graph('./data/tox21.txt')
    bace_graphs = read_data('bace')
    bace_dataset = CustomDataset(bace_graphs, root='./dataset/bace/processed/')
    bace_dataset.save()

    # Reload the dataset:
    # dataset = CustomDataset(root='./dataset/bace/processed/').load()
    # dataset.load()

    # data_list = torch.load('../dataset/bace/processed/geometric_data_processed.pt')
    # print(type(data_list))
    # for i in data_list:
    #     atom_enc = AtomEncoder(4)
    #     i.x = atom_enc(i.x)
    # data = Batch.from_data_list(data_list)
    # # batch = Batch.from_data_list(data)

    # print(data)
